Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interface functions to allow replacing the log density function and replacing AD wrapper type #33

Closed
wants to merge 3 commits into from

Conversation

sunxd3
Copy link

@sunxd3 sunxd3 commented Jul 11, 2024

Ref #32 (comment)

Brief summary:

  • added replace_ℓ interface function
  • if ADgradient take in a ADGradientWrapper, then recreate a new gradient wrapper with its log density function

I only added some implementations for ReverseDiff.

This is very much a draft right now, everything is up to modify.

@sunxd3
Copy link
Author

sunxd3 commented Jul 11, 2024

Copy link
Owner

@tpapp tpapp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also see comment in discussion.

Project.toml Show resolved Hide resolved
src/LogDensityProblemsAD.jl Show resolved Hide resolved
src/LogDensityProblemsAD.jl Outdated Show resolved Hide resolved
@tpapp
Copy link
Owner

tpapp commented Jul 12, 2024

It is unfortunate that the ADgradient constructor takes keywords, not structs, for the legacy interface, since if the gradient spec was all wrapped in a single container then we could just do

ADgradient(get_AD(ℓ), new_ℓ)

and just implement get_AD instead. If we wait for #29 then we could have that kind of API instead of replace_ℓ.

@sunxd3
Copy link
Author

sunxd3 commented Jul 12, 2024

If we wait for #29 then we could have that kind of API instead of replace_ℓ

It is cleaner. We can opt for it after the PR is merged.

@tpapp
Copy link
Owner

tpapp commented Jul 12, 2024

But then we would have to change the interface again...

I am inclined to go with replace_ℓ for now, with a note saying that it is experimental API at the moment and may just change. So the relevant PR in Turing could proceed.

But will wait to hear from @devmotion.

@devmotion
Copy link
Collaborator

My impression from TuringLang/Turing.jl#2231 (comment) and related comments in Turing.jl was that there's no clear need for such an API currently? One reason for such an API would be a case where calling ADgradient from scratch would be less efficient than a dedicated replace_l function (BTW IMO probably an official API - even if it is experimental - should use a non-Unicode name, potentially with a Unicode alias (but an alias seems a bit much for such a simple functionality)). But at least for the ReverseDiff example here there's no efficiency gain?

Regarding the implementation: Couldn't we achieve this functionality by overloading setproperty!?

@torfjelde
Copy link
Contributor

It is unfortunate that the ADgradient constructor takes keywords, not structs, for the legacy interface, since if the gradient spec was all wrapped in a single container then we could just do

Does the ADTypes.jl extension not effectively solve this? Or are there some kwargs that are still missing from the ADTypes.jl structs?

@sunxd3
Copy link
Author

sunxd3 commented Jul 16, 2024

I have a new proposal: add an interface function getADtype (or some other better name) and don't add the interface function this PR is trying to introduce. getADtype should return a ADTypes.ADType. Then packages can just use ADgradient with ADType to create the wrapper.

EDIT: just realized this is exactly what @tpapp was suggesting 👍

The motivation is that I don't think replace_l would be enough. At least for ReverseDiff, one failure mode is that the tape compiled without specifying input (i.e. kwarg x) can result in a tape that is not correct for all inputs (something related to control flow maybe? @yebai). In that case, we really need the ability to call ADgradient with kwargs.

@tpapp
Copy link
Owner

tpapp commented Jul 17, 2024

Sorry for the late responses, I am on holiday with limited net access.

@torfjelde: the problem is that not all the API is using ADtypes.

@sunxd3: yes, the cleanest solution would be that, see my comment above. But we need to clean up the API first.

I am not sure how pressing is the need for this solution. We could introduce something interim that solves the problem for Turing, with the understanding that it is internal and would be removed once we solve this.

@sunxd3
Copy link
Author

sunxd3 commented Jul 17, 2024

@tpapp understood your point now.

There are motivations to introduce such an API. Correctness of ReverseDiff's tape is one. Do we want to wait for #29 or maybe introduce something like getADtype, which is just one function that returns the ADType if supported?

@torfjelde
Copy link
Contributor

the problem is that not all the API is using ADtypes.

Gotcha 👍

I am not sure how pressing is the need for this solution. We could introduce something interim that solves the problem for Turing, with the understanding that it is internal and would be removed once we solve this.

We have a work-around on our side, so I think it's less pressing atm

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants